package com.yeziji.utils;

import cn.hutool.extra.spring.SpringUtil;
import com.yeziji.common.base.IUserDetailsBase;
import com.yeziji.common.base.UserServletInfoBase;
import com.yeziji.common.context.OnlineContext;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Supplier;

/**
 * Thread pool management utility.
 * <p>
 * Looks up the project's named {@link ThreadPoolTaskExecutor} beans through Spring and submits tasks to them,
 * propagating the submitting thread's {@link OnlineContext} (user details and servlet info) to the worker thread.
 *
 * @author hwy
 * @see com.yeziji.config.ThreadPoolConfig
 * @since 2023/07/29 13:32
 **/
public class ThreadPoolUtils {
    public static final String CUSTOM = "customExecutor";
    public static final String IO = "ioTaskExecutor";
    public static final String CPU = "cpuTaskExecutor";

    /**
     * Returns the executor bean named {@value #IO}.
     */
    public static ThreadPoolTaskExecutor getIoThreadPool() {
        return getThreadPool(IO);
    }

    /**
     * Returns the executor bean named {@value #CPU}.
     */
    public static ThreadPoolTaskExecutor getCpuThreadPool() {
        return getThreadPool(CPU);
    }

    /**
     * Returns the executor bean named {@value #CUSTOM}.
     */
    public static ThreadPoolTaskExecutor getCustomThreadPool() {
        return getThreadPool(CUSTOM);
    }

    /**
     * Looks up a {@link ThreadPoolTaskExecutor} bean by name from the Spring context.
     *
     * @param poolName the bean name
     * @return the executor bean
     */
    public static ThreadPoolTaskExecutor getThreadPool(final String poolName) {
        return SpringUtil.getBean(poolName);
    }

    /**
     * Runs {@code supplier} on the given executor with the caller's online context and waits for its result.
     * <p>
     * On timeout the task is NOT cancelled; it keeps running on the pool thread.
     *
     * @param supplier               the task to run
     * @param waitSecond             maximum time to wait for the result, in seconds
     * @param threadPoolTaskExecutor the executor to run the task on
     * @param <U>                    result type
     * @return the task's result
     * @throws RuntimeException wrapping {@link InterruptedException} (the interrupt flag is restored),
     *                          {@link ExecutionException} (the task failed) or {@link TimeoutException}
     */
    public static <U> U execute(Supplier<U> supplier, int waitSecond, ThreadPoolTaskExecutor threadPoolTaskExecutor) {
        CompletableFuture<U> future = CompletableFuture.supplyAsync(withOnlineContext(supplier), threadPoolTaskExecutor);
        try {
            return future.get(waitSecond, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            // Restore the interrupt status so callers further up the stack can still observe it.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (ExecutionException | TimeoutException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Submits {@code runnable} to the given executor with the caller's online context, without waiting.
     * <p>
     * NOTE(review): the returned future is discarded, so an exception thrown by {@code runnable} is not reported
     * anywhere by this method.
     *
     * @param runnable               the task to run
     * @param threadPoolTaskExecutor the executor to run the task on
     */
    public static void execute(Runnable runnable, ThreadPoolTaskExecutor threadPoolTaskExecutor) {
        Supplier<Void> task = withOnlineContext(() -> {
            runnable.run();
            return null;
        });
        CompletableFuture.runAsync(task::get, threadPoolTaskExecutor);
    }

    /**
     * Captures the current thread's online context and wraps {@code task} so that it runs with that context.
     * <p>
     * When the wrapped task runs on a different thread, the captured context is installed before the task and
     * cleared afterwards. When it runs on the submitting thread itself (e.g. a caller-runs rejection policy),
     * the context is already in place, so it is left untouched instead of being cleared out from under the caller.
     */
    private static <U> Supplier<U> withOnlineContext(Supplier<U> task) {
        final Thread submitter = Thread.currentThread();
        final IUserDetailsBase current = OnlineContext.getOnlineUserOnlineInfo();
        // Only propagate user details that actually identify a user; otherwise propagate null.
        final IUserDetailsBase userDetails = (current != null && current.getUserId() != null) ? current : null;
        final UserServletInfoBase servletInfo = OnlineContext.getOnlineUserServletInfo();
        return () -> {
            if (Thread.currentThread() == submitter) {
                return task.get();
            }
            try {
                OnlineContext.setOnlineUserDetails(userDetails);
                OnlineContext.setOnlineUserServletInfo(servletInfo);
                return task.get();
            } finally {
                OnlineContext.clear();
            }
        };
    }

    // ---- IO thread pool

    /**
     * Runs {@code supplier} on the IO pool and waits up to {@code waitSecond} seconds for the result.
     */
    public static <U> U ioExecute(Supplier<U> supplier, int waitSecond) {
        return execute(supplier, waitSecond, getIoThreadPool());
    }

    /**
     * Submits {@code runnable} to the IO pool without waiting.
     */
    public static void ioExecute(Runnable runnable) {
        execute(runnable, getIoThreadPool());
    }

    // ---- CPU thread pool

    /**
     * Runs {@code supplier} on the CPU pool and waits up to {@code waitSecond} seconds for the result.
     */
    public static <U> U cpuExecute(Supplier<U> supplier, int waitSecond) {
        return execute(supplier, waitSecond, getCpuThreadPool());
    }

    /**
     * Submits {@code runnable} to the CPU pool without waiting.
     */
    public static void cpuExecute(Runnable runnable) {
        execute(runnable, getCpuThreadPool());
    }

    // ---- Custom thread pool

    /**
     * Runs {@code supplier} on the custom pool and waits up to {@code waitSecond} seconds for the result.
     */
    public static <U> U customExecute(Supplier<U> supplier, int waitSecond) {
        return execute(supplier, waitSecond, getCustomThreadPool());
    }

    /**
     * Submits {@code runnable} to the custom pool without waiting.
     */
    public static void customExecute(Runnable runnable) {
        execute(runnable, getCustomThreadPool());
    }
}
